f46874
@@ -27,8 +27,6 @@
 import java.util.Set;
 import java.util.Stack;
 
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.common.JavaUtils;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.AppMasterEventOperator;
@@ -38,16 +36,13 @@
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
 import org.apache.hadoop.hive.ql.exec.GroupByOperator;
 import org.apache.hadoop.hive.ql.exec.JoinOperator;
-import org.apache.hadoop.hive.ql.exec.LateralViewJoinOperator;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.MuxOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.OperatorFactory;
 import org.apache.hadoop.hive.ql.exec.OperatorUtils;
-import org.apache.hadoop.hive.ql.exec.PTFOperator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.TezDummyStoreOperator;
-import org.apache.hadoop.hive.ql.exec.UDTFOperator;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
@@ -64,8 +59,8 @@
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.Statistics;
 import org.apache.hadoop.util.ReflectionUtils;
-
-import com.google.common.collect.ImmutableSet;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * ConvertJoinMapJoin is an optimization that replaces a common join
@@ -78,16 +73,6 @@
 
   private static final Logger LOG = LoggerFactory.getLogger(ConvertJoinMapJoin.class.getName());
 
-  @SuppressWarnings({ "unchecked", "rawtypes" })
-  private static final Set<Class<? extends Operator<?>>> COSTLY_OPERATORS =
-          new ImmutableSet.Builder()
-                  .add(CommonJoinOperator.class)
-                  .add(GroupByOperator.class)
-                  .add(LateralViewJoinOperator.class)
-                  .add(PTFOperator.class)
-                  .add(ReduceSinkOperator.class)
-                  .add(UDTFOperator.class)
-                  .build();
 
   @Override
   /*
@@ -146,9 +131,11 @@
       }
     }
 
-    LOG.info("Convert to non-bucketed map join");
     // check if we can convert to map join no bucket scaling.
-    mapJoinConversionPos = getMapJoinConversionPos(joinOp, context, 1);
+    LOG.info("Convert to non-bucketed map join");
+    if (numBuckets != 1) {
+      mapJoinConversionPos = getMapJoinConversionPos(joinOp, context, 1);
+    }
     if (mapJoinConversionPos < 0) {
       // we are just converting to a common merge join operator. The shuffle
       // join in map-reduce case.
@@ -557,8 +544,8 @@
public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
         HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD);
 
     int bigTablePosition = -1;
-    // number of costly ops (Join, GB, PTF/Windowing, TF) below the big input
-    int bigInputNumberCostlyOps = -1;
+    // big input cumulative row count
+    long bigInputCumulativeCardinality = -1L;
     // stats of the big input
     Statistics bigInputStat = null;
 
@@ -602,18 +589,27 @@
public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
         }
       }
 
-      int currentInputNumberCostlyOps = foundInputNotFittingInMemory ?
-              -1 : OperatorUtils.countOperatorsUpstream(parentOp, COSTLY_OPERATORS);
+      long currentInputCumulativeCardinality;
+      if (foundInputNotFittingInMemory) {
+        currentInputCumulativeCardinality = -1L;
+      } else {
+        Long cardinality = computeCumulativeCardinality(parentOp);
+        if (cardinality == null) {
+          // We could not get stats, we cannot convert
+          return -1;
+        }
+        currentInputCumulativeCardinality = cardinality;
+      }
 
       // This input is the big table if it is contained in the big candidates set, and either:
       // 1) we have not chosen a big table yet, or
       // 2) it has been chosen as the big table above, or
-      // 3) the number of costly operators for this input is higher, or
-      // 4) the number of costly operators is equal, but the size is bigger,
+      // 3) the cumulative cardinality for this input is higher, or
+      // 4) the cumulative cardinality is equal, but the size is bigger,
       boolean selectedBigTable = bigTableCandidateSet.contains(pos) &&
               (bigInputStat == null || currentInputNotFittingInMemory ||
-                      (!foundInputNotFittingInMemory && (currentInputNumberCostlyOps > bigInputNumberCostlyOps ||
-                              (currentInputNumberCostlyOps == bigInputNumberCostlyOps && inputSize > bigInputStat.getDataSize()))));
+                      (!foundInputNotFittingInMemory && (currentInputCumulativeCardinality > bigInputCumulativeCardinality ||
+                              (currentInputCumulativeCardinality == bigInputCumulativeCardinality && inputSize > bigInputStat.getDataSize()))));
 
       if (bigInputStat != null && selectedBigTable) {
         // We are replacing the current big table with a new one, thus
@@ -633,7 +629,7 @@
public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
 
       if (selectedBigTable) {
         bigTablePosition = pos;
-        bigInputNumberCostlyOps = currentInputNumberCostlyOps;
+        bigInputCumulativeCardinality = currentInputCumulativeCardinality;
         bigInputStat = currInputStat;
       }
 
@@ -642,6 +638,39 @@
public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
     return bigTablePosition;
   }
 
+  /**
+   * Computes the cumulative row count of the plan ending at {@code op}; this is akin to the
+   * CBO cumulative cardinality model. The operator's own row count is added to the maximum of
+   * its parents' cumulative counts for a CommonJoinOperator, or to their sum otherwise.
+   *
+   * @param op operator whose cumulative cardinality is computed
+   * @return the cumulative cardinality, or null if a visited operator has no statistics
+   */
+  private static Long computeCumulativeCardinality(Operator<? extends OperatorDesc> op) {
+    final boolean chooseMax = op instanceof CommonJoinOperator;
+    long cumulativeCardinality = 0L;
+    for (Operator<? extends OperatorDesc> inputOp : op.getParentOperators()) {
+      Long inputCardinality = computeCumulativeCardinality(inputOp);
+      if (inputCardinality == null) {
+        return null;
+      }
+      cumulativeCardinality = chooseMax ? Math.max(cumulativeCardinality, inputCardinality)
+          : saturatedAdd(cumulativeCardinality, inputCardinality);
+    }
+    Statistics currInputStat = op.getStatistics();
+    if (currInputStat == null) {
+      LOG.warn("Couldn't get statistics from: " + op);
+      return null;
+    }
+    return saturatedAdd(cumulativeCardinality, currInputStat.getNumRows());
+  }
+
+  // Adds two row counts; a sum of non-negative counts that wraps is clamped to Long.MAX_VALUE
+  private static long saturatedAdd(long a, long b) {
+    long sum = a + b;
+    return (a >= 0 && b >= 0 && sum < 0) ? Long.MAX_VALUE : sum;
+  }
+
   /*
    * Once we have decided on the map join, the tree would transform from
    *
